# Calculating excess load based on OGE BA data
library(pacman)
p_load(
  fst, data.table, here, purrr, stringr, lubridate, 
  collapse, haven, readxl
)
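# TODO: 2020 was a leap year (8,784 hours rather than 8,760) -- verify that
#   any per-year hour counts downstream account for the extra day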

# Function to calculate line losses -------------------------------------------
# Estimates line-loss parameters by region and nets losses out of the load
# series.
#
# Argument:
#   nerc_load_dt  data.table with one row per region-hour. Must contain
#                 `nerc_adj` plus the columns `load`, `excess_load`, `load_ff`,
#                 `load_nd`, `load_solar`, `load_hydro`, `load_coal`, `load_gas`.
# Returns:
#   A new data.table holding the input columns plus `alpha_1`, `alpha_2`,
#   `tot_loss`, and one `<col>_loss_adj` column per load column above. Regions
#   without a loss estimate are kept (left join) with NA loss columns.
# Side effect:
#   Writes the estimated alphas to Data/electricity-generation/line-loss-alphas.csv
calculate_line_losses = function(nerc_load_dt){
  # Data from borenstein bushnell
  bb_dt = read_dta(here('Data/SeverinMC/output.dta')) |> data.table()
  # Crosswalk between utilities and regions (EIA-861 'States' sheets)
  utility_nerc_xwalk = 
    rbind(
      # First 2014 data
      data.table(read_xls(
        here('Data/electricity-generation/eia861/f8612014/Utility_Data_2014.xls'),
        sheet = 'States',
        skip = 1
      )),
      # Now 2015 data
      data.table(read_xlsx(
        here('Data/electricity-generation/eia861/f8612015/Utility_Data_2015.xlsx'),
        sheet = 'States',
        skip = 1
      )),
      # Finally 2016 data
      data.table(read_xlsx(
        here('Data/electricity-generation/eia861/f8612016/Utility_Data_2016.xlsx'),
        sheet = 'States',
        skip = 1
      ))
    )[,.(
      year = `Data Year`,
      eia_id_d = `Utility Number`,
      state = State, 
      in_miso = MISO,
      in_pjm = PJM,
      in_ercot = ERCOT,
      in_caiso = CAISO,
      nerc = `NERC Region`    
    )]
  # Doing some cleaning--not perfect
  # fcase() takes the first matching rule, so ISO membership and state
  # overrides win over the reported NERC region. Rows matching no rule
  # (including the AK/HI rule) get NA and are dropped below.
  utility_nerc_xwalk[,
    nerc_adj := fcase(
      in_miso == 'Y', 'MRO',
      in_pjm == 'Y', 'RFC',
      in_caiso == 'Y', 'CAL',
      state == 'CA', 'CAL',
      nerc == 'FRCC', 'SERC',
      nerc == 'ECAR' & state == 'KY', 'SERC',
      nerc == 'ERCOT', 'TRE',
      nerc == 'MRO/WECC', 'MRO',
      nerc == 'NY', 'NPCC',
      nerc == 'SPP', 'MRO',
      nerc == '25470', 'MRO',
      nerc %in% c('AK','HI','HICC'), NA_character_,
      !is.na(nerc), toupper(nerc)
    )
  ]
  # Average loss share for each region: residential-sales-weighted mean of
  # `lossresdist`, divided by 100 (presumably reported in percent -- confirm)
  loss_dt = 
    merge(
      bb_dt, 
      utility_nerc_xwalk,
      by = c('year','eia_id_d')
    )[!is.na(nerc_adj),.(
      avg_pct_losses = weighted.mean(lossresdist, w = res_sales)/100), 
      keyby = nerc_adj
    ]
  # Load moments needed for the loss parameters. The observation count is
  # named `n_hours` rather than `T` so it cannot be confused with TRUE.
  load_moments_dt = 
    nerc_load_dt[,.(
      n_hours = .N,
      tot_load = sum(load),
      tot_load_sq = sum(load^2)), 
      keyby = nerc_adj
    ]
  # Calculating loss parameters. Hourly loss is alpha_1 + alpha_2*load^2;
  # summed over a region's hours this equals avg_pct_losses*tot_load, with
  # 25% assigned to the constant term and 75% to the quadratic term.
  alpha_dt = 
    merge(
      load_moments_dt,
      loss_dt, 
      by = 'nerc_adj'
    )[,.(
      nerc_adj,
      alpha_1 = 0.25*avg_pct_losses*tot_load/n_hours,
      alpha_2 = 0.75*avg_pct_losses*tot_load/tot_load_sq
    )]
  fwrite(
    alpha_dt, 
    file = here('Data/electricity-generation/line-loss-alphas.csv')
  )
  # Adjusting load and excess load by losses
  load_cols = c(
    'load', 'excess_load', 'load_ff', 'load_nd', 
    'load_solar', 'load_hydro', 'load_coal', 'load_gas'
  )
  adj_cols = paste0(load_cols, '_loss_adj')
  # Drop outputs left over from an earlier run, but only those actually
  # present: deselecting a missing column makes data.table warn, and leftover
  # alpha columns would collide in the merge.
  stale_cols = intersect(
    c('loss', 'tot_loss', 'alpha_1', 'alpha_2', adj_cols),
    names(nerc_load_dt)
  )
  keep_cols = setdiff(names(nerc_load_dt), stale_cols)
  nerc_load_dt = 
    merge(
      nerc_load_dt[, keep_cols, with = FALSE], 
      alpha_dt, 
      by = 'nerc_adj',
      all.x = TRUE
    )
  nerc_load_dt[, tot_loss := alpha_1 + alpha_2*(load^2)]
  # Same adjustment for every series: x - (alpha_1 + alpha_2*x^2)
  nerc_load_dt[,
    (adj_cols) := lapply(.SD, \(x){x - (alpha_1 + alpha_2*(x^2))}),
    .SDcols = load_cols
  ]
  nerc_load_dt
}


# Builds hourly load series from the OGE balancing-authority (BA) files.
#
# Steps: read the 2019 and 2020 hourly BA files, censor outlying generation
# values, compute load measures by BA-hour, balance the panel, aggregate BAs to
# `nerc_adj` regions, adjust for line losses, and save the results.
#
# Takes no arguments. Writes
#   Data/electricity-generation/clean-load/ba-load-dt-oge.fst   (BA level)
#   Data/electricity-generation/clean-load/nerc-load-dt-oge.fst (region level)
# and prints progress/diagnostics along the way. Called for its side effects.
excess_load_region = function(){
# Loading data ----------------------------------------------------------------
  print('loading data')
  # Reads every file in one year's OGE folder and stacks them, tagging each
  # row with the BA code (file name minus '.csv') and the folder's year
  read_oge_year = function(yr){
    year_dir = here(paste0(
      'Data/electricity-generation/open-grid-emissions/',
      yr, '_power_sector_data_hourly_us_units'
    ))
    rbindlist(
      lapply(
        list.files(year_dir),
        \(ba_in){
          fread(file.path(year_dir, ba_in))[,
            ba_code := str_remove(ba_in, '\\.csv')
          ][,year := yr]
        }
      ),
      use.names = TRUE,
      fill = TRUE
    )
  }
  # Loading the balancing authority level data
  oge_ba_dt = 
    rbindlist(
      lapply(c(2019, 2020), read_oge_year),
      use.names = TRUE,
      fill = TRUE
    ) |> setkey(ba_code, fuel_category, datetime_utc)
  # A few hours have data reported in both years 
  # NOTE(review): duplicates are summed, so an hour reported with the same
  # value in both files would be double counted -- confirm this is intended
  oge_ba_dt = oge_ba_dt[,.(
    net_generation_mwh = sum(net_generation_mwh, na.rm = TRUE)), 
    keyby = .(ba_code, fuel_category, datetime_utc)
  ]
  # Checking for outliers 
  #p_load(ggplot2)
  #ggplot(
  #  oge_ba_dt[
  #    ba_code %in% c('ISNE','NBSO','NYIS','RIMS')
  #    #ba_code %in% c('WWA','GWA')
  #  ], 
  #  aes(x = datetime_utc, y = net_generation_mwh_clean, color = fuel_category)
  #)+ 
  #  geom_line() + 
  #  facet_wrap(~ba_code, scales = 'free')
  # Calculating outlier threshold (99th percentile by BA and fuel)
  ba_fuel_type_outliers = 
    oge_ba_dt[,.(
      max_net_generation_mwh = quantile(net_generation_mwh, probs = 0.99)), 
      keyby = .(ba_code, fuel_category)
    ]
  # Adding lags to the data (n = 0 keeps the current value; 1h and 24h lags)
  oge_ba_dt =
    oge_ba_dt |>
    gby(ba_code, fuel_category) |>
    flag(t = datetime_utc, n = c(0,1,24)) |>
    fungroup()
  # Censoring---Setting to 1h or 24h lag if above 1.5*99th percentile
  oge_ba_dt = 
    merge(
      oge_ba_dt, 
      ba_fuel_type_outliers, 
      by = c('ba_code','fuel_category')
    )
  oge_ba_dt[,cap := 1.5*max_net_generation_mwh]
  # Rules apply in order: keep the value if it lies in [0, cap]; otherwise use
  # the 1h lag, then the 24h lag, when that lag lies strictly inside (0, cap);
  # otherwise fall back to the cap itself.
  # NOTE(review): a negative reading whose lags are both <= 0 also ends up at
  # the cap (not at 0) -- confirm that is the intended treatment.
  oge_ba_dt[,':='(
    net_generation_mwh_clean = fcase(
      net_generation_mwh >= 0 & 
      net_generation_mwh <= cap, 
        net_generation_mwh,
      !is.na(L1.net_generation_mwh) & 
      L1.net_generation_mwh > 0 & 
      L1.net_generation_mwh < cap, 
        L1.net_generation_mwh,
      !is.na(L24.net_generation_mwh) & 
      L24.net_generation_mwh > 0 & 
      L24.net_generation_mwh < cap, 
        L24.net_generation_mwh,
      !is.na(max_net_generation_mwh), 
        cap,
      default = 0
    ),
    # Flag is 1 whenever the raw value was replaced. The lag-specific branches
    # all returned 1, so they collapse into the single fallback branch.
    imp_flag = fcase(
      net_generation_mwh >= 0 & 
      net_generation_mwh <= cap, 
        0,
      !is.na(max_net_generation_mwh), 
        1,
      default = 0
    )
  )]
  # Dropping the helper columns
  oge_ba_dt[,
    c(
      'cap', 'max_net_generation_mwh', 
      'L1.net_generation_mwh', 'L24.net_generation_mwh'
    ) := NULL
  ]
  print(oge_ba_dt[,.(`Percent imputed` = mean(imp_flag))])

# Calculating load ------------------------------------------------------------
  print('calculating load')
  # One row per BA-hour with a column per fuel category (absent fuels = 0)
  oge_ba_wide_dt = 
    dcast(
      oge_ba_dt,
      formula = ba_code + datetime_utc ~ fuel_category,
      value.var = 'net_generation_mwh_clean',
      fill = 0
    )
  oge_ba_load_dt = 
    oge_ba_wide_dt[,.(
      ba_code, datetime_utc,
      load = total,
      excess_load = total - wind - solar,
      load_nd = wind + solar,
      load_ff = coal + natural_gas + petroleum,
      load_solar = solar,
      load_hydro = hydro,
      load_coal = coal,
      load_gas = natural_gas
    )]
  # Making it a balanced panel
  all_ba_hour_dt = 
    CJ(
      ba_code = unique(oge_ba_dt$ba_code),
      datetime_utc = seq(
        min(oge_ba_dt$datetime_utc),
        max(oge_ba_dt$datetime_utc), 
        by = 'hour'
      )
    ) |>
    setkey(ba_code, datetime_utc)
  oge_ba_load_dt = 
    merge(
      all_ba_hour_dt,
      oge_ba_load_dt, 
      by = c("ba_code","datetime_utc"),
      all.x = TRUE
    ) 
  # Share of BAs with missing load, first and last 20 hours of the panel
  missing_by_hour_dt = 
    oge_ba_load_dt[,.(`Missing Load` = mean(is.na(load))), keyby = datetime_utc]
  print(head(missing_by_hour_dt, n = 20))
  print(tail(missing_by_hour_dt, n = 20))
  # setting na to zero (every column except the two id columns)
  load_cols = setdiff(colnames(oge_ba_load_dt), c('ba_code','datetime_utc'))
  oge_ba_load_dt[,
    (load_cols) := lapply(.SD, \(x){fifelse(is.na(x),0,x)}),
    .SDcols = load_cols
  ]
  # Analysis window, defined once so that the filter and the balanced-panel
  # check below cannot drift apart
  window_start = '2019-01-01 10:00:00'
  window_end = '2021-01-01 03:00:00'
  oge_ba_load_dt = 
    oge_ba_load_dt[datetime_utc %between% c(window_start, window_end)]
  # Saving balancing authority load data
  write.fst(
    oge_ba_load_dt,
    here("Data/electricity-generation/clean-load/ba-load-dt-oge.fst")
  )
# Aggregating BA to region ---------------------------------------------------
  ba_nerc_xwalk = fread(here('Data/electricity-generation/ba-nerc-xwalk-oge.csv'))
  # Aggregating to nerc region: sums every column with "load" in its name.
  # BAs missing from the crosswalk are kept (left join) under nerc_adj = NA.
  oge_nerc_load_dt =
    merge(
      oge_ba_load_dt,
      ba_nerc_xwalk,
      by = c("ba_code"),
      all.x = TRUE
    )[,
      lapply(.SD, sum, na.rm = TRUE),
      keyby = .(nerc_adj, datetime_utc),
      .SDcols = str_subset(colnames(oge_ba_load_dt), "load")
    ]
  # Checking that it is a balanced panel 
  nhour = seq(
    ymd_hms(window_start), 
    ymd_hms(window_end), 
    by = 'hour'
  ) |> length()
  print(paste0('Expect ', nhour,' obs per nerc adj.'))
  print(oge_nerc_load_dt[,.N,by = nerc_adj])
  # Adjusting for line losses 
  oge_nerc_load_dt = calculate_line_losses(oge_nerc_load_dt)
  # Saving the results 
  write.fst(
    x = oge_nerc_load_dt,
    path = here("Data/electricity-generation/clean-load/nerc-load-dt-oge.fst")
  )
  print('done.')
}
  #p_load(ggplot2)  
  #ggplot(oge_nerc_load_dt, aes(x = datetime_utc, y = excess_load)) + 
  #  geom_line() + 
  #  facet_wrap(~nerc_adj, scales = 'free')
  
# Reloading the region-level load table written by excess_load_region()
oge_nerc_load_dt = read.fst(
  here("Data/electricity-generation/clean-load/nerc-load-dt-oge.fst"),
  as.data.table = TRUE
)
  